import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

from model import NaiveClassifier, LabelNoiseClassifier
from data_creation import percentile_from_md, add_data_noise, add_categorical_noise, add_data_noise_tabular
from data_loaders import get_clean_loaders, get_noisy_loaders

from earlystoppingtool import EarlyStopping
from privacy_functions import calculate_xnoise_ynoise_from_epsilon, calculate_image_metrics
    
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'       # Fix for weird OMP error

from argparse import Namespace
from torch import device, tensor
import random

# just for debugging purpose
from torchvision.utils import make_grid
# Within this file, only used to squash tensors for the debug image grid in test().
sigmoid = torch.nn.Sigmoid()


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    Without a ``type``, argparse keeps the raw string, and every non-empty
    string (including 'False') is truthy, so boolean options that take a value
    need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=128, type=int, help='training batch size')
parser.add_argument('--trainset_size', default=None, type=int,
                    help='Size of the training set (set to None for full training set)')

parser.add_argument('--epochs', default=500, type=int, help='number of epochs during training')
parser.add_argument('--lr', default=1e-4, type=float, help='Learning Rate.')
parser.add_argument('--log_interval', default=1000, type=int, help='how frequently to print loss')

parser.add_argument('--dropout_rate', default=0., type=float, help='Dropout rate.')

parser.add_argument('--pixel_level', action='store_true',
                    help='if true, reconstruct noisy latent. If false, use noisy latent')

parser.add_argument('--noise_features_directly', action='store_true',
                    help='if true, add noise directly to the features and classify')
parser.add_argument('--task', default='MNIST', type=str, help='One of [MNIST, LendingClub], needed when specifying --noise_features_directly')
parser.add_argument('--noise_type', default='Laplace', type=str, help='One of [Gaussian, Laplace, Binary], needed when specifying --noise_features_directly')
parser.add_argument('--fraction', default=1., type=float, help='fraction to split pretraining and classification data, needed when specifying --noise_features_directly')


parser.add_argument('--epsilon', default=8.0, type=float, help='epsilon differential privacy levels')
parser.add_argument('--epsilon_split', default=0.7, type=float,
                    help='proportion of epsilon that x_noise contributes to')

parser.add_argument('--grid_directory', default=None, type=str, help='save to dif directory for grid runs')
parser.add_argument('--param_dir', default='runs/Rep_MNIST_latentLaplace5_MDtrain5.0_std1.76',
                    type=str, help='directory where pre-trained params are stored.')
parser.add_argument('--posterior_dim', default=50, type=int, help='dimension of layers in posterior z net')

parser.add_argument('--patience', default=20, type=int, help='early stopping patience')
# type=_str2bool so that `--validation False` really disables validation
# (a raw 'False' string would be truthy).
parser.add_argument('--validation', default=True, type=_str2bool,
                    help='run with validation set, if False no early stopping is used')

parser.add_argument('--md', default=None, type=float,
                    help='Mahalanobis distance defining the clipping point. If None, use delta_f')
parser.add_argument('--dp_encoder', action='store_true', help='use a dp encoder, which is saved in param_dir')
parser.add_argument('--dp_decoder', action='store_true', help='use a dp decoder, which is saved in param_dir')
parser.add_argument('--central_epsilon', default=None, type=float, help='Central epsilon from pre-trained model')

parser.add_argument('--conv_classifier', action='store_true', help='if true use conv classifier, else use dense.')

parser.add_argument('--use_label_noise', action='store_true', help='classifier model that incorporates label noise.')

parser.add_argument('--synthetic_generation', action='store_true', help='use if generating synthetic data and label')
parser.add_argument('--novel_class', action='store_true', help='use if experimenting with the distributional shift')
parser.add_argument('--data_join_task', action='store_true', help='use if doing the data join task')

def load_opt_from_file(path):
    """Load the options stored as a Python expression in ``<path>/options.txt``.

    The file is expected to hold ``str(namespace)`` of an argparse Namespace
    (``save_opt_to_file`` writes options in that form). It is rebuilt with
    ``eval``. The module imports ``Namespace``, ``device`` and ``tensor``,
    presumably so those names resolve inside such a repr.

    Args:
        path: directory containing ``options.txt``.

    Returns:
        The evaluated object (expected to be an ``argparse.Namespace``).
    """
    # SECURITY: eval executes arbitrary code from the file. Only point
    # --param_dir at trusted, locally produced run directories.
    with open(path + '/options.txt', 'r') as f:
        data = f.read()
    return eval(data)


def save_opt_to_file(dic, path):
    """Write ``str(dic)`` to ``<path>/classifier_options.txt``, overwriting it.

    Args:
        dic: the options object (in this script, the argparse Namespace).
        path: existing output directory.
    """
    with open(path + '/classifier_options.txt', 'w') as f:
        f.write(str(dic))


opt, unknown = parser.parse_known_args()

# When a pre-trained representation model exists, inherit its settings so the
# classifier matches it; without one, only --noise_features_directly can run.
if os.path.exists(opt.param_dir + '/options.txt'):
    pretrain_opt = load_opt_from_file(opt.param_dir)
    opt.task = pretrain_opt.task
    opt.z_dim = pretrain_opt.rep_dim
    opt.noise_type = pretrain_opt.latent_distn
    opt.rep_dim = pretrain_opt.rep_dim
    if pretrain_opt.md and pretrain_opt.posterior_std:
        # 2 * md acts as the sensitivity (the __main__ section also uses
        # delta_f = 2 * md); std / sqrt(2) is presumably a Laplace scale --
        # confirm for other latent distributions.
        opt.posterior_epsilon = (2 * pretrain_opt.md) / (pretrain_opt.posterior_std / np.sqrt(2))
    # A rep model trained with one DP component cannot supply the other one.
    if pretrain_opt.dp_encoder and opt.dp_decoder:
        raise ValueError('No dp decoder was trained in this rep model')
    elif pretrain_opt.dp_decoder and opt.dp_encoder:
        raise ValueError('No dp encoder was trained in this rep model')
    # NB: fraction always refers to the pre-training split.
    opt.fraction = pretrain_opt.fraction
elif not opt.noise_features_directly:
    raise ValueError("Path {} does not exist, and --noise_features_directly was not specified".format(opt.param_dir+'/options.txt'))
################################################################################


def train(opt, model, data_loader, epoch, iteration, optimizer, writer):
    """Run one training epoch of the noisy classifier.

    Args:
        opt: run options; reads ``data_join_task``, ``noise_features_directly``,
            ``n_continuous_features``, ``device`` and ``log_interval``.
        model: classifier with ``train()`` and
            ``loss(noisy_x, clean_x, noisy_y, clean_y)`` returning ``(loss, _)``.
        data_loader: yields ``(x, y)`` batches, or
            ``(x_noisy_split, x_clean_split, y)`` when ``opt.data_join_task``.
        epoch: epoch number, used only for logging.
        iteration: global step count before this epoch.
        optimizer: optimizer stepping the model parameters.
        writer: TensorBoard writer receiving ``train/Loss`` per step.

    Returns:
        The global step after the last batch (``iteration`` if the loader is empty).

    Note:
        In the data-join task, the module-level ``norm_mean``/``norm_std``
        (created in the ``__main__`` section) normalise the noisy split.
    """
    model.train()
    losses = []
    # Initialised up front so an empty loader returns the unchanged step count
    # instead of raising UnboundLocalError.
    current_iter = iteration
    for batch_idx, datapoint in enumerate(data_loader):
        if opt.data_join_task:
            data_split_noisy, data_split_clean, noisy_label = datapoint     # NB: no noise added to label
            # Both halves must share a device before torch.cat below.
            data_split_clean = data_split_clean.to(opt.device)
            data_split_noisy = data_split_noisy.to(opt.device)
            # normalise noisy data before concatenating, only continuous features in the first case
            if opt.noise_features_directly:
                n_cont = opt.n_continuous_features
                data_split_noisy[:, :n_cont] = (data_split_noisy[:, :n_cont] -
                                                norm_mean[:n_cont]) / norm_std[:n_cont]
            else:
                data_split_noisy = (data_split_noisy - norm_mean) / norm_std

            noisy_data = torch.cat([data_split_clean, data_split_noisy], dim=1)
        else:
            noisy_data, noisy_label = datapoint

        current_iter = iteration + batch_idx + 1
        optimizer.zero_grad()

        # Labels become float column vectors of shape (batch, 1) for the loss.
        noisy_label = noisy_label.float().unsqueeze(1)

        loss, _ = model.loss(noisy_data.to(opt.device), None,
                             noisy_label.to(opt.device), None)
        loss.backward()
        loss_value = loss.item()
        losses.append(loss_value)

        writer.add_scalar('train/Loss', loss_value, current_iter)

        optimizer.step()

        if batch_idx % opt.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
                epoch, batch_idx * len(noisy_data), len(data_loader.dataset),
                100. * batch_idx / len(data_loader),
                loss_value))

    # Avoid numpy's "mean of empty slice" warning; the printed value stays nan.
    mean_loss = np.mean(losses) if losses else float('nan')
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, mean_loss))

    return current_iter


def validate(model, data_loader, epoch, device, writer):
    """Evaluate the classifier on the (noisy) validation loader.

    At most batches 0..100 are evaluated. The mean loss and accuracies are
    logged under ``val/`` in ``writer``. The mean clean (and, unless
    ``opt.synthetic_generation``, noisy) accuracy is appended to
    ``<out_dir>/true_val_accuracies.csv``.

    Args:
        model: classifier with ``eval()`` and
            ``loss(noisy_x, clean_x, noisy_y, clean_y, eval_accuracy=...)``
            returning ``(loss, (accuracy_clean, accuracy_noisy))``.
        data_loader: yields ``(x, y)`` batches, or
            ``(x_noisy_split, x_clean_split, y)`` in the data-join task.
        epoch: epoch number, used as the TensorBoard step.
        device: device the batches are moved to.
        writer: TensorBoard writer.

    Returns:
        The mean validation loss (the ``__main__`` section feeds it to early stopping).

    Note:
        Reads the module-level ``opt``, ``out_dir``, ``norm_mean`` and
        ``norm_std`` set up in the ``__main__`` section.
    """
    model.eval()
    losses = []
    accuracies_clean = []
    accuracies_noisy = []
    with torch.no_grad():
        for batch_idx, datapoint in enumerate(data_loader):
            # Cap validation cost. Checked before any work is done on the batch;
            # batches 0..batch_idx-1 (= batch_idx batches) have been evaluated.
            if batch_idx > 100:
                print('Val loss evaluated on {} iters'.format(batch_idx))
                break

            if opt.data_join_task:
                data_split_noisy, data_split_clean, noisy_label = datapoint     # NB: no noise added to label
                # Both halves must share a device before torch.cat below.
                data_split_clean = data_split_clean.to(device)
                data_split_noisy = data_split_noisy.to(device)
                # normalise noisy data before concatenating (only the continuous
                # features when noise was added to the features directly)
                if opt.noise_features_directly:
                    n_cont = opt.n_continuous_features
                    data_split_noisy[:, :n_cont] = (data_split_noisy[:, :n_cont] -
                                                    norm_mean[:n_cont]) / norm_std[:n_cont]
                else:
                    data_split_noisy = (data_split_noisy - norm_mean) / norm_std
                noisy_data = torch.cat([data_split_clean, data_split_noisy], dim=1)
            else:
                noisy_data, noisy_label = datapoint

            noisy_label = noisy_label.float().unsqueeze(1)

            if opt.data_join_task:
                # label is clean here, so it doubles as the true label
                loss, (accuracy_clean, accuracy_noisy) = model.loss(noisy_data.to(device), None, noisy_label.to(device),
                                                                    noisy_label.to(device), eval_accuracy=True)
            else:
                loss, (accuracy_clean, accuracy_noisy) = model.loss(noisy_data.to(device), None, noisy_label.to(device),
                                                                    None, eval_accuracy=False)
            losses.append(loss.item())
            accuracies_clean.append(accuracy_clean)
            accuracies_noisy.append(accuracy_noisy)

    mean_loss = np.mean(losses)
    writer.add_scalar('val/Loss', mean_loss, epoch)
    writer.add_scalar('val/accuracy_clean', np.mean(accuracies_clean), epoch)
    writer.add_scalar('val/accuracy_noisy', np.mean(accuracies_noisy), epoch)

    with open(out_dir + '/true_val_accuracies.csv', 'a+') as myfile:
        if opt.synthetic_generation:
            myfile.write(str(np.mean(accuracies_clean)) + '\n')
        else:
            myfile.write(str(np.mean(accuracies_clean)) + ', ')
            myfile.write(str(np.mean(accuracies_noisy)) + '\n')

    print('\nEpoch: {}\tVal loss: {:.6f}\n\n'.format(epoch, mean_loss))
    return mean_loss
    

def test(model, encoder, decoder, data_loader, x_noise, y_noise, epoch, entire_testset, device, writer, noise_type):
    """Evaluate the classifier on clean test data, privatised on the fly.

    Each clean batch is turned into noisy inputs/labels in one of three ways:
    adding noise to the raw features (``opt.noise_features_directly``), decoding
    a noised latent (``opt.pixel_level``), or adding noise to the encoder
    representation. The classifier is then scored against both the noisy and
    the clean labels. With ``opt.synthetic_generation``, no noise is added and
    only the clean accuracy is tracked.

    Args:
        model: classifier with ``eval()`` and
            ``loss(noisy_x, clean_x, noisy_y, clean_y, eval_accuracy=True)``
            returning ``(loss, (accuracy_clean, accuracy_noisy))``.
        encoder, decoder: pre-trained representation model (``None`` when
            noising features directly).
        data_loader: clean test loader yielding ``(x, y)``, or
            ``(x, x_join_clean, y)`` in the data-join task.
        x_noise, y_noise: feature and label noise levels.
        epoch: epoch number, used as the TensorBoard step.
        entire_testset: if False, only batches 0..50 are evaluated.
        device: device the data is moved to.
        writer: TensorBoard writer.
        noise_type: noise distribution passed to ``add_data_noise``.

    Note:
        Reads the module-level ``opt``, ``out_dir``, ``norm_mean``, ``norm_std``
        and ``sigmoid`` defined elsewhere in this script. Appends accuracies to
        ``<out_dir>/true_test_accuracies.csv``.
    """
    model.eval()
    losses = []
    accuracies_clean = []
    accuracies_noisy = []
    # Hold the last batch for the debug image below; stay None if no batch ran.
    clean_data = noisy_data = None

    with torch.no_grad():
        for batch_idx, datapoint in enumerate(data_loader):
            # Checked before any work is done on the batch; at this point
            # batch_idx batches (0..batch_idx-1) have been evaluated.
            if batch_idx > 50 and not entire_testset:
                print('Test loss evaluated on {} iters'.format(batch_idx))
                break

            if opt.data_join_task:
                data, data_join_clean, label = datapoint
                data_join_clean = data_join_clean.to(device)
            else:
                data, label = datapoint

            if opt.synthetic_generation:
                clean_data = data.to(device)
                noisy_data = noisy_label = None
            elif opt.noise_features_directly:
                clean_data = data.to(device)
                if opt.tabular:
                    noisy_data = add_data_noise_tabular(clean_data, opt)
                else:
                    noisy_data = add_data_noise(clean_data, x_noise, noise_type)
                noisy_label = add_categorical_noise(label.to(device),
                                                    opt.n_categories,
                                                    y_noise).float().unsqueeze(1)
            else:
                if opt.pixel_level:
                    # noise the latent and decode it back to input space
                    latents = encoder.get_data_representatation(data.to(device), data_loader=False, clip=opt.md)
                    clean_data = data.to(device)
                    noisy_data = decoder.get_data_reconstruction(latents.to(device), x_noise, clip=opt.md)
                else:
                    # classify directly on the noised representation
                    clean_data = encoder.get_data_representatation(data.to(device), data_loader=False, clip=opt.md)
                    noisy_data = add_data_noise(clean_data, x_noise, noise_type)
                noisy_label = add_categorical_noise(label.to(device),
                                                    opt.n_categories,
                                                    y_noise).float().unsqueeze(1)

            if opt.data_join_task:
                # normalise noisy data before concatenating (only the continuous
                # features when noise was added to the features directly)
                if opt.noise_features_directly:
                    n_cont = opt.n_continuous_features
                    noisy_data[:, :n_cont] = (noisy_data[:, :n_cont] -
                                              norm_mean[:n_cont]) / norm_std[:n_cont]
                else:
                    noisy_data = (noisy_data - norm_mean) / norm_std
                noisy_data = torch.cat([data_join_clean, noisy_data], dim=1)

            loss, (accuracy_clean, accuracy_noisy) = model.loss(noisy_data, clean_data, noisy_label,
                                                                label.to(device).float().unsqueeze(1),
                                                                eval_accuracy=True)

            if not opt.synthetic_generation:
                losses.append(loss.item())
                accuracies_noisy.append(accuracy_noisy)
            accuracies_clean.append(accuracy_clean)

    # just for debugging purpose: log clean and noisy inputs of the last batch
    if (opt.pixel_level and not opt.synthetic_generation and not opt.data_join_task
            and clean_data is not None):
        x_with_mu = torch.cat((clean_data.to(device), noisy_data))
        tag = 'noisy_images' if opt.noise_features_directly else 'reconstructions'
        writer.add_image(tag=tag,
                         img_tensor=make_grid(sigmoid(x_with_mu)),
                         global_step=0)

    writer.add_scalar('test/accuracy_clean', np.mean(accuracies_clean), epoch)

    if opt.synthetic_generation:
        print('\nEpoch: {}\tTest clean accuracy: {:.6f}'.format(epoch, np.mean(accuracies_clean)))
    else:
        writer.add_scalar('test/Loss', np.mean(losses), epoch)
        writer.add_scalar('test/accuracy_noisy', np.mean(accuracies_noisy), epoch)
        print('\nEpoch: {}\tTest loss: {:.6f}\tTest clean accuracy: {:.6f}, \tTest noisy '
              'accuracy: {:.6f}\n\n'.format(epoch, np.mean(losses),
                                            np.mean(accuracies_clean),
                                            np.mean(accuracies_noisy)))

    with open(out_dir + '/true_test_accuracies.csv', 'a+') as myfile:
        if opt.synthetic_generation:
            myfile.write(str(np.mean(accuracies_clean)) + '\n')
        else:
            myfile.write(str(np.mean(accuracies_clean)) + ', ')
            myfile.write(str(np.mean(accuracies_noisy)) + '\n')

################################################################################


if __name__ == '__main__':
    # Noising raw features directly, or generating synthetic data, always works
    # at input ("pixel") level.
    if opt.noise_features_directly or opt.synthetic_generation:
        opt.pixel_level = True
    if opt.data_join_task:
        if opt.noise_features_directly:
            opt.pixel_level = True
        else:
            opt.pixel_level = False
        # Labels are not noised in the data-join task, so the whole epsilon
        # budget goes to the features.
        opt.epsilon_split = 1.0

    # ---- output directory & seeding ----------------------------------------
    level = 'Pixel' if opt.pixel_level else 'Latent'
    # Suffix of param_dir starting at 'VAE' is used in the run name (idx == -1 if absent).
    idx = opt.param_dir.rfind('VAE')
    if opt.grid_directory is not None:
        out_dir = './grid_runs/'+opt.grid_directory
        if opt.noise_features_directly:
            if opt.use_label_noise:
                out_dir += 'LabelNoise{}'.format(level)
            else:
                out_dir += 'Naive{}'.format(level)
        else:
            if opt.use_label_noise:
                out_dir += 'LabelNoise{}_{}'.format(level, opt.param_dir[idx:])
            else:
                out_dir += 'Naive{}_{}'.format(level, opt.param_dir[idx:])
        # this should set the seed to the trial number in the grid search
        # assuming the directory is in the form of 'trialx/' -- only the
        # second-to-last character is read, so x must be a single digit
        try:
            seed = int(opt.grid_directory[-2])
        except (ValueError, IndexError):
            print('The random seed may have not been assigned properly')
            np.random.seed(4)
            torch.manual_seed(4)
        else:
            print(f'The seed is {seed}')
            np.random.seed(seed)
            torch.manual_seed(seed)
            if opt.noise_features_directly:
                torch.cuda.manual_seed(seed)
                random.seed(seed)
                torch.backends.cudnn.enabled = False
                torch.backends.cudnn.deterministic = True
    else:
        out_dir = './runs/'
        if opt.noise_features_directly:
            out_dir += 'direct_pixel/PixelDirectly'
        if opt.synthetic_generation:
            out_dir += 'Synth'
        if opt.novel_class:
            out_dir += 'dShift'
        if opt.use_label_noise:
            if opt.noise_features_directly:
                out_dir += 'LabelNoise{}'.format(level)
            else:
                if idx == -1:
                    out_dir += 'LabelNoise{}'.format(level)
                else:
                    out_dir += 'LabelNoise{}_{}'.format(level, opt.param_dir[idx:])
        else:
            if idx == -1:
                out_dir += 'Naive{}'.format(level)
            else:
                out_dir += 'Naive{}_{}'.format(level, opt.param_dir[idx:])
        np.random.seed(4)
        torch.manual_seed(4)

    out_dir += '_task{}'.format(opt.task)
    out_dir += '_centraleps{}'.format(opt.central_epsilon)
    out_dir += '_localeps{}split{}'.format(opt.epsilon, opt.epsilon_split)

    if not opt.noise_features_directly:
        if opt.md is not None:
            out_dir += '_MD{}'.format(opt.md)
        if '_weakdec' in opt.param_dir:
            out_dir += '_weakdec'

    out_dir += '_{:%m%d%H%M}'.format(datetime.now())

    os.makedirs(out_dir, exist_ok=True)
    writer = SummaryWriter(out_dir)

    use_cuda = torch.cuda.is_available()
    opt.device = torch.device("cuda" if use_cuda else "cpu")

    # ---- data & representation models --------------------------------------
    clean_train_loader, clean_val_loader, clean_test_loader, recon_dataset, opt.image_dim,  opt.clean_join_data_dim, opt.tabular, \
    opt.n_continuous_features, opt.ncat_of_cat_features, opt.n_categories, opt.maximum_pixel_difference \
        = get_clean_loaders(opt, writer=writer, stage='classification')

    if opt.noise_features_directly:
        opt.encoder, opt.decoder = None, None
    else:
        if opt.dp_encoder:
            opt.encoder = torch.load(os.path.join(opt.param_dir, 'dp_rep_model_eps{}.pth'.format(opt.central_epsilon)))
        else:
            opt.encoder = torch.load(os.path.join(opt.param_dir, 'rep_model.pth'))
        if opt.dp_decoder:
            opt.decoder = torch.load(os.path.join(opt.param_dir, 'dp_rep_model_eps{}.pth'.format(opt.central_epsilon)))
        else:
            opt.decoder = torch.load(os.path.join(opt.param_dir, 'rep_model.pth'))

        opt.encoder.eval()
        opt.decoder.eval()

    # ---- sensitivity (delta_f) and noise levels ----------------------------
    if not (opt.synthetic_generation or opt.epsilon == float('inf')):
        if opt.noise_features_directly:
            if opt.tabular:
                opt.delta_f = opt.maximum_pixel_difference
            else:
                opt.delta_f = np.prod(opt.image_dim)*opt.maximum_pixel_difference
        elif opt.md is not None:
            opt.prior_clip = percentile_from_md(opt.noise_type, opt.z_dim, opt.md)
            opt.delta_f = 2 * opt.md
        elif opt.epsilon > 9999:
            print('Since no noise is being added, delta_f is set to None')
            opt.delta_f = None
        else:
            raise NotImplementedError("")

    print("Calculating noise levels...")
    if opt.synthetic_generation:
        opt.x_noise, opt.y_noise = None, None
    else:
        if opt.epsilon == 0.:
            opt.x_noise, opt.y_noise = 30., 0.9
        elif opt.epsilon > 9999.:
            if opt.noise_features_directly and opt.tabular:
                # One zero noise level per feature. maximum_pixel_difference is
                # what delta_f is set to for direct tabular noise, and unlike
                # delta_f it also exists when epsilon is inf.
                total_number_of_features = len(opt.maximum_pixel_difference)+len(opt.ncat_of_cat_features)
                opt.x_noise, opt.y_noise = torch.zeros((total_number_of_features), device=opt.device), 0.
            else:
                opt.x_noise, opt.y_noise = 0., 0.
        else:
            opt.x_noise, opt.y_noise = calculate_xnoise_ynoise_from_epsilon(opt)

    print("Creating noisy dataset...")
    # norm_mean / norm_std are module globals read by train/validate/test.
    noisy_train_loader, noisy_val_loader, norm_mean, norm_std = get_noisy_loaders(opt)
    print('Noisy dataset created.')

    if opt.validation:
        early_stopping = EarlyStopping(patience=opt.patience, verbose=True, outdir=out_dir)

    print("\n", "=" * 50)
    print("Hyperparameters")
    for x, y in vars(opt).items():
        print(x, y)
    print("=" * 50, "\n")

    if opt.pixel_level and not opt.synthetic_generation:
        if opt.tabular:
            print("no rec metrics for tabular data")
        else:
            print("Calculating image reconstruction metrics...")
            image_metrics = calculate_image_metrics(clean_test_loader, opt.encoder, opt.decoder, opt.x_noise, opt.md, opt.device)
            with open(out_dir+'/image_metrics.txt', 'w') as f:
                f.write(str(image_metrics))
            print('metrics calculated on test set, privatised with noise {}'.format(opt.x_noise))

    # TRAIN NOISY CLASSIFICATION MODEL
    if opt.use_label_noise:
        classifier_model = LabelNoiseClassifier(opt).to(opt.device)
    else:
        classifier_model = NaiveClassifier(opt).to(opt.device)
    classifier_optimizer = optim.Adam(classifier_model.network.parameters(), lr=opt.lr)

    iters = 0
    for epoch in range(0, opt.epochs):
        iters = train(opt, classifier_model, noisy_train_loader, epoch + 1, iters,
                      classifier_optimizer, writer)

        if opt.validation:
            val_loss = validate(classifier_model, noisy_val_loader, epoch + 1, opt.device, writer)

        # the full test set is only evaluated on the final epoch
        entire_testset = (epoch == opt.epochs - 1)
        test(classifier_model, opt.encoder, opt.decoder, clean_test_loader, opt.x_noise, opt.y_noise, epoch + 1,
             entire_testset, opt.device, writer, opt.noise_type)

        print('Saving params...')
        torch.save(classifier_model.state_dict(),
                   os.path.join(out_dir, 'noisy_classifier_params.pth'))
        print('Params saved.')

        if opt.validation:
            early_stopping(val_loss, classifier_model)

            if early_stopping.early_stop:
                print("Patience limit reached")
                break

    save_opt_to_file(opt, out_dir)

